import torch
import torchvision
import time
import copy

def train_model(device, model, dataloaders, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and return it loaded with its best validation weights.

    Each epoch runs a 'train' phase followed by a 'val' phase over
    ``dataloaders['train']`` and ``dataloaders['val']``. Every sample is fed to
    the network as 10 separate 3x224x224 images. Their sigmoid outputs are
    compared row-by-row against 10 labels ("Acc"). Per sample, the
    highest-scoring of the last 8 rows (``twohot_decode``) is compared
    against ``answers`` ("Real Acc").

    Args:
        device: device each batch is moved to. The model is assumed to
            already be on this device; this function never moves it.
        model: network applied to a (batch*10, 3, 224, 224) tensor. Presumably
            it returns one logit per row, shape (batch*10, 1); TODO confirm.
        dataloaders: dict with 'train' and 'val' loaders yielding
            ``(inputs, labels, answers)`` triples.
        criterion: loss applied to the sigmoid outputs and labels viewed as
            (batch*10, 1).
        optimizer: optimizer stepped once per training batch.
        scheduler: LR scheduler stepped once per epoch, right after that
            epoch's training phase.
        num_epochs: number of train/val epochs to run.

    Returns:
        ``(model, val_acc_history)``. ``model`` holds the weights with the
        highest validation "real" accuracy. ``val_acc_history`` lists each
        epoch's validation real accuracy in percent, as Python floats.
    """
    since = time.time()
    val_acc_history = []

    # keep a copy of the best performing weights seen so far
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # training and validation phase in each epoch
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # set model to training mode
            else:
                # one LR-schedule step per epoch, after this epoch's optimizer steps
                scheduler.step()
                model.eval()  # set model to evaluation mode

            running_loss = 0.0
            running_corrects = 0  # number of rows whose rounded output equals its label
            corrects = 0          # number of samples whose decoded candidate equals the answer

            for inputs, labels, answers in dataloaders[phase]:
                batch = inputs.shape[0]
                # flatten so the network sees each of the 10 rows as a separate image
                inputs = inputs.view(batch * 10, 3, 224, 224).to(device)
                labels = labels.view(batch * 10, 1).to(device)
                answers = answers.to(device)

                optimizer.zero_grad()

                # track history only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = normalize(model(inputs))  # apply sigmoid
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics need no autograd history
                outputs = outputs.detach().view(batch, 10)
                preds = torch.round(outputs)
                labels = labels.view(batch, 10)

                # NOTE(review): loss.item() is summed once per batch but divided by
                # dataset_size*10 below. If criterion averages (reduction='mean'),
                # epoch_loss is not a per-row mean. Confirm the criterion's reduction.
                running_loss += loss.item()
                running_corrects += torch.sum(preds == labels).item()
                corrects += torch.sum(answers == twohot_decode(outputs)).item()

            dataset_size = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / (dataset_size * 10)
            epoch_acc = running_corrects / (dataset_size * 10) * 100
            real_acc = corrects / dataset_size * 100

            print('{} Loss: {:.4f} Acc: {:.4f}%'.format(phase, epoch_loss, epoch_acc))
            print('Real Acc: {:.4f}%'.format(real_acc))

            if phase == 'val':
                # deep copy the model if it is the best so far
                if real_acc > best_acc:
                    best_acc = real_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                val_acc_history.append(real_acc)

        print()

    time_elapsed = time.time() - since  # track the training time
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights and return the best model
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

def get_target():
    """Build the pseudo target vector.

    Returns a float32 tensor of length 10 whose first two entries are 1
    and whose remaining eight entries are 0.
    """
    target = torch.zeros(10, dtype=torch.float32)
    target[:2] = 1.0  # the two leading "positive" positions
    return target

def twohot_decode(X):
    """Pick the best candidate per row from a score tensor.

    Drops the first two columns of ``X`` and takes columns 2..9 (eight
    candidates, via ``narrow`` along dim 1). Returns, along dim 1, the
    index (0-7) of the highest-scoring candidate. Assumes ``X`` is
    (batch, >=10); TODO confirm.
    """
    candidates = X.narrow(1, 2, 8)
    return candidates.argmax(dim=1)

def normalize(X):
    """Map raw model outputs through the element-wise logistic sigmoid."""
    return X.sigmoid()

def freeze_BatchNorm2d(model):
    """Disable gradients for the parameters of every ``BatchNorm2d`` in ``model``.

    Walks the whole module tree via ``model.modules()``, so BatchNorm2d
    layers are found at any nesting depth. This includes a BN nested inside
    a residual block's sub-``Sequential`` (e.g. a ResNet-18 ``downsample``
    branch), not just top-level layers or direct children of a block.

    Only ``requires_grad`` on the layers' parameters is changed. Running
    statistics are buffers, not parameters, so this function does not touch
    them.

    Args:
        model: any ``torch.nn.Module`` (e.g. a ResNet-18); modified in place.

    Returns:
        None.
    """
    for module in model.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            for param in module.parameters():
                param.requires_grad = False
    return

def test_model(device, model, dataloaders, checkpoint_path='all_unsup.pt'):
    """Evaluate ``model`` on ``dataloaders['test']`` and print its accuracy.

    For each sample, the 10 rows are run through the network and passed
    through a sigmoid. The highest-scoring of the last 8 rows
    (``twohot_decode``) is compared against ``answers``.

    Args:
        device: device the model and batches are moved to; also used as
            ``map_location`` when loading the checkpoint.
        model: network whose weights are (optionally) replaced by the
            checkpoint before evaluation.
        dataloaders: dict with a 'test' loader yielding
            ``(inputs, labels, answers)`` triples; ``labels`` is not used here.
        checkpoint_path: state-dict file to load before testing. Pass ``None``
            to skip loading, e.g. when testing directly after training.

    Returns:
        None. Results are printed.
    """
    since = time.time()

    if checkpoint_path is not None:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(state)
    model.to(device)

    print('Testing...')
    print('-' * 10)

    # only testing phase
    phase = 'test'
    model.eval()  # set model to evaluation mode
    corrects = 0  # number of samples whose decoded candidate equals the answer

    # inference only: no autograd history is needed
    with torch.no_grad():
        for inputs, _labels, answers in dataloaders[phase]:
            batch = inputs.shape[0]
            # flatten so the network sees each of the 10 rows as a separate image
            inputs = inputs.view(batch * 10, 3, 224, 224).to(device)
            answers = answers.to(device)

            outputs = normalize(model(inputs))  # apply sigmoid
            preds = twohot_decode(outputs.view(batch, 10))

            corrects += torch.sum(preds == answers).item()

    real_acc = corrects / len(dataloaders[phase].dataset) * 100
    print('Acc: {:.4f}%'.format(real_acc))

    print()

    time_elapsed = time.time() - since  # track testing time
    print('Test complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    return
